load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test")
load("//psi/algorithm/spiral:copts.bzl", "spiral_copts")

# Package-wide default: every target declared below is visible to all other
# packages unless it sets its own `visibility` attribute.
package(default_visibility = ["//visibility:public"])

# Library target exposing arith.h. No `srcs` are listed, so as declared here
# the target contributes only that header.
# Depends on the shared spiral `common` target, Abseil strings, SEAL, and the
# yacl exception / int128 / gadget targets.
# NOTE(review): psi_cc_library is a macro from //bazel:psi.bzl; any defaults it
# adds (copts, extra deps) are not visible from this file.
psi_cc_library(
    name = "arith",
    hdrs = ["arith.h"],
    deps = [
        "//psi/algorithm/spiral:common",
        "@abseil-cpp//absl/strings",
        "@seal",
        "@yacl//yacl/base:exception",
        "@yacl//yacl/base:int128",
        "@yacl//yacl/math:gadget",
    ],
)

# Library target exposing arith_params.h. No `srcs` are listed, so as declared
# here the target contributes only that header.
# Builds on the sibling `:arith` target and additionally pulls in the spiral
# `params` target alongside `common`, Abseil strings, SEAL, and the yacl
# exception / int128 targets.
psi_cc_library(
    name = "arith_params",
    hdrs = ["arith_params.h"],
    deps = [
        ":arith",
        "//psi/algorithm/spiral:common",
        "//psi/algorithm/spiral:params",
        "@abseil-cpp//absl/strings",
        "@seal",
        "@yacl//yacl/base:exception",
        "@yacl//yacl/base:int128",
    ],
)

# Library target exposing number_theory.h. No `srcs` are listed, so as declared
# here the target contributes only that header.
# Depends on the sibling `:arith` target plus the spiral `common` target,
# Abseil strings, SEAL, and the yacl exception target.
psi_cc_library(
    name = "number_theory",
    hdrs = ["number_theory.h"],
    deps = [
        ":arith",
        "//psi/algorithm/spiral:common",
        "@abseil-cpp//absl/strings",
        "@seal",
        "@yacl//yacl/base:exception",
    ],
)

# Library target compiled from ntt_table.cc with public header ntt_table.h.
# Depends on the sibling `:arith` and `:number_theory` targets plus the spiral
# `common` target, SEAL, and the yacl exception target.
# NOTE(review): unlike `:ntt`, this target does not pass `spiral_copts()` —
# confirm that is intentional.
psi_cc_library(
    name = "ntt_table",
    srcs = ["ntt_table.cc"],
    hdrs = ["ntt_table.h"],
    deps = [
        ":arith",
        ":number_theory",
        "//psi/algorithm/spiral:common",
        "@seal",
        "@yacl//yacl/base:exception",
    ],
)

# Library target compiled from ntt.cc with public header ntt.h, built with the
# compiler options returned by spiral_copts().
# The unconditional deps are the sibling `:arith`, `:ntt_table` and
# `:number_theory` targets, the spiral `params` target, Abseil span, SEAL, and
# the yacl aligned_vector / exception / parallel targets.
# On aarch64 only, `@sse2neon` is appended to the deps; every other CPU adds
# nothing.
# NOTE(review): presumably ntt.cc uses x86 SSE intrinsics that sse2neon maps to
# NEON — confirm against the source file.
psi_cc_library(
    name = "ntt",
    srcs = ["ntt.cc"],
    hdrs = ["ntt.h"],
    copts = spiral_copts(),
    deps = [
        ":arith",
        ":ntt_table",
        ":number_theory",
        "//psi/algorithm/spiral:params",
        "@abseil-cpp//absl/types:span",
        "@seal",
        "@yacl//yacl/base:aligned_vector",
        "@yacl//yacl/base:exception",
        "@yacl//yacl/utils:parallel",
    ] + select({
        "@platforms//cpu:aarch64": ["@sse2neon"],
        "//conditions:default": [],
    }),
)

# Test target compiled from arith_test.cc. Links both `:arith` and
# `:arith_params`, plus the spiral `util` target, Abseil strings, and SEAL.
# NOTE(review): no test-framework dependency is listed; presumably the
# psi_cc_test macro supplies it — confirm in //bazel:psi.bzl.
psi_cc_test(
    name = "arith_test",
    srcs = ["arith_test.cc"],
    deps = [
        ":arith",
        ":arith_params",
        "//psi/algorithm/spiral:util",
        "@abseil-cpp//absl/strings",
        "@seal",
    ],
)

# Test target compiled from number_theory_test.cc; its only declared dependency
# is the sibling `:number_theory` library.
psi_cc_test(
    name = "number_theory_test",
    srcs = ["number_theory_test.cc"],
    deps = [":number_theory"],
)

# Test target compiled from ntt_table_test.cc, built with the compiler options
# returned by spiral_copts() (the same call `:ntt` uses).
# Links both `:ntt` and `:ntt_table`, plus the spiral `params` and `util`
# targets, Abseil span, and the yacl aligned_vector target.
psi_cc_test(
    name = "ntt_table_test",
    srcs = ["ntt_table_test.cc"],
    copts = spiral_copts(),
    deps = [
        ":ntt",
        ":ntt_table",
        "//psi/algorithm/spiral:params",
        "//psi/algorithm/spiral:util",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/base:aligned_vector",
    ],
)
